import numpy as np

from sklearn import metrics

def ensemble_block_kernel(X, Y, kernel=lambda x, y: metrics.pairwise.rbf_kernel(x, y, gamma=None)):
    """Compute the block kernel between X and Y.

    Both inputs are flattened over their first two axes and ``kernel`` is
    evaluated once on the flattened arrays. For a pairwise kernel the result
    is an ``n_ensemble x n_ensemble`` grid of ``n_samples x n_samples``
    blocks, where block ``(i, j)`` equals ``kernel(X[i], Y[j])``.

    Parameters
    ----------
    X : ndarray of shape (n_ensemble, n_samples, n_features)

    Y : ndarray of shape (n_ensemble, n_samples, n_features)
        Must have exactly the same shape as ``X``.

    kernel : callable, default=sklearn ``rbf_kernel`` with ``gamma=None``
        Called as ``kernel(A, B)`` with 2-D arrays of shape
        ``(n_ensemble * n_samples, n_features)``; expected to return the
        matrix of kernel values between the rows of ``A`` and ``B``.

    Returns
    -------
    kernel_block_matrix : ndarray of shape (n_ensemble * n_samples, n_ensemble * n_samples)
        Row ``i * n_samples + j`` corresponds to ``X[i, j]`` and column
        ``i * n_samples + j`` corresponds to ``Y[i, j]``.

    Raises
    ------
    ValueError
        If ``X`` and ``Y`` differ in shape or are not 3-D.
    """
    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if X.shape != Y.shape:
        raise ValueError(
            f"X and Y must have the same shape, got {X.shape} and {Y.shape}"
        )
    if X.ndim != 3:
        raise ValueError(
            "X and Y must be 3-D (n_ensemble, n_samples, n_features), "
            f"got shape {X.shape}"
        )
    n_ens, n_samp, n_feat = X.shape
    # C-order reshape is ensemble-major: all samples of ensemble member 0
    # come first, then all samples of member 1, and so on.
    X_flat = X.reshape(n_ens * n_samp, n_feat)
    Y_flat = Y.reshape(n_ens * n_samp, n_feat)
    return kernel(X_flat, Y_flat)

def get_diag_blocks(kernel_block_matrix, n_ensemble):
    """Extract the diagonal blocks K_ii of a square kernel block matrix.

    The (n*m) x (n*m) matrix K is viewed as an m x m grid of n x n blocks,
    where n = n_samples and m = n_ensemble. The m diagonal blocks
    K_11, ..., K_mm are copied out and stacked into a new array.

    Parameters
    ----------
    kernel_block_matrix : ndarray of shape (n_ensemble * n_samples, n_ensemble * n_samples)

    n_ensemble : int
        Number of blocks along each axis.

    Returns
    -------
    diag_blocks : ndarray of shape (n_ensemble, n_samples, n_samples)
        ``diag_blocks[i]`` is the block ``K_ii``.

    Raises
    ------
    ValueError
        If the number of rows is not a multiple of ``n_ensemble``.
    """
    n_all = kernel_block_matrix.shape[0]
    # Exact integer division plus an explicit check: a mismatched n_ensemble
    # would otherwise silently drop the trailing rows/columns.
    n_samples, remainder = divmod(n_all, n_ensemble)
    if remainder:
        raise ValueError(
            f"matrix size {n_all} is not a multiple of n_ensemble={n_ensemble}"
        )
    return np.array([
        # select the n x n matrix K_ii
        kernel_block_matrix[i * n_samples:(i + 1) * n_samples, i * n_samples:(i + 1) * n_samples]
        for i in range(n_ensemble)
    ])

def avg_off_diag_blocks(kXY, diag_blocks):
    """Average every entry of the off-diagonal blocks K_ij (i != j) of a block kernel matrix.

    The off-diagonal total is the sum of the whole matrix minus the sum of the
    diagonal blocks. It is divided by the number of off-diagonal entries,
    n_ensemble * (n_ensemble - 1) * n_samples**2.

    Parameters
    ----------
    kXY : ndarray of shape (n_ensemble * n_samples, n_ensemble * n_samples)

    diag_blocks : ndarray of shape (n_ensemble, n_samples, n_samples)
        The diagonal blocks of ``kXY``, e.g. as produced by ``get_diag_blocks``.

    Returns
    -------
    float
        With a single ensemble member the denominator is zero. For numpy
        inputs this yields nan/inf with a RuntimeWarning rather than an
        exception.
    """
    block_size = diag_blocks[0].shape[0]
    n_members = len(diag_blocks)
    off_block_total = kXY.sum() - diag_blocks.sum()
    n_off_entries = block_size * block_size * n_members * (n_members - 1)
    return off_block_total / n_off_entries

def dist_cov(X, Y, kernel=lambda x, y: metrics.pairwise.rbf_kernel(x, y, gamma=None)):
    """
    Distributional covariance generated by kernel between samples X and Y.

    Estimated as the mean of the within-member blocks of the block kernel
    K(X, Y), minus the mean of the cross-member blocks (i != j). The
    within-member term keeps the entries pairing X[i, j] with Y[i, j].

    Parameters
    ----------
    X : ndarray of shape (n_ensemble, n_samples, n_features)

    Y : ndarray of shape (n_ensemble, n_samples, n_features)

    kernel : callable
        Pairwise kernel; defaults to sklearn ``rbf_kernel`` with ``gamma=None``.

    Returns
    -------
    float
    """
    # Unpacking also rejects inputs that are not 3-D.
    n_ens, _, _ = X.shape
    block_kernel = ensemble_block_kernel(X, Y, kernel=kernel)
    within = get_diag_blocks(block_kernel, n_ens)
    within_mean = np.mean(within)
    return within_mean - avg_off_diag_blocks(block_kernel, within)

def dist_var(X, kernel=lambda x, y: metrics.pairwise.rbf_kernel(x, y, gamma=None)):
    """
    Distributional variance generated by kernel and estimated on samples X.

    The within-member term averages k(x_ij, x_it) over j != t. The
    self-similarities on the main diagonal of K(X, X) are excluded. The
    average over the cross-member blocks (i != j) is then subtracted.

    Parameters
    ----------
    X : ndarray of shape (n_ensemble, n_samples, n_features)
        ``n_samples`` must be at least 2, otherwise the within-member
        denominator is zero.

    kernel : callable
        Pairwise kernel; defaults to sklearn ``rbf_kernel`` with ``gamma=None``.

    Returns
    -------
    float
    """
    n_ens, n_samp, _ = X.shape
    gram = ensemble_block_kernel(X, X, kernel=kernel)
    within = get_diag_blocks(gram, n_ens)
    # Number of ordered pairs (j, t), j != t, summed over all members.
    n_pairs = n_ens * n_samp * (n_samp - 1)
    inner = (within.sum() - gram.diagonal().sum()) / n_pairs
    return inner - avg_off_diag_blocks(gram, within)

def dist_corr(X, Y, kernel=lambda x, y: metrics.pairwise.rbf_kernel(x, y, gamma=None)):
    """
    Distributional correlation generated by kernel between samples X and Y.

    The covariance and both variance terms are computed with ``dist_cov``,
    whose within-member term keeps the self-pairs k(x, x). For the variances
    this is the biased (but consistent) estimator.

    With these estimators the variance of X equals
    (1 / (n_ensemble - 1)) * sum_i ||mu_i - mean(mu)||^2 >= 0, where mu_i is
    the empirical kernel mean embedding of ensemble member i. The covariance
    is the corresponding inner product. For a positive-definite kernel the
    Cauchy-Schwarz inequality therefore guarantees a result in [-1, 1]
    (up to floating-point error). The correlation can be negative. The
    original rationale for the biased variance terms is that an unbiased
    estimator of the correlation does not exist anyway.

    Parameters
    ----------
    X : ndarray of shape (n_ensemble, n_samples, n_features)

    Y : ndarray of shape (n_ensemble, n_samples, n_features)

    kernel : callable
        Pairwise kernel; defaults to sklearn ``rbf_kernel`` with ``gamma=None``.

    Returns
    -------
    float
        Undefined (nan/inf) when either variance estimate is zero, e.g. when
        all ensemble members of X (or of Y) are identical.
    """
    var_X = dist_cov(X, X, kernel=kernel)
    var_Y = dist_cov(Y, Y, kernel=kernel)
    cov_XY = dist_cov(X, Y, kernel=kernel)
    return cov_XY / np.sqrt(var_X * var_Y)

def kernel_error(X, Y, kernel=lambda x, y: metrics.pairwise.rbf_kernel(x, y, gamma=None)):
    """Estimate the kernel score of the sample ``X`` against the observations ``Y``.

    Computes

        mean_{i != j} k(x_i, x_j) - 2 * mean_{i, l} k(x_i, y_l),

    an unbiased estimate of ||mu_P||_k^2 - 2 <mu_P, mu_Q>_k when the rows of
    ``X`` are i.i.d. draws from P, independent of the rows of ``Y`` (drawn
    from Q).

    Parameters
    ----------
    X : ndarray of shape (n_sample1, dim)
        Requires ``n_sample1 >= 2``.
    Y : ndarray of shape (n_sample2, dim)
        Requires ``n_sample2 >= 1``.
    kernel : callable
        Pairwise kernel ``kernel(A, B)`` returning the matrix of kernel
        values between the rows of A and B. Defaults to sklearn's
        ``rbf_kernel`` with ``gamma=None``, i.e.
        k(x, y) = exp(-gamma * ||x - y||^2) with gamma = 1 / dim.

    Returns
    -------
    float
        The (negatively oriented) kernel score estimate.

    Raises
    ------
    ValueError
        If ``X`` has fewer than 2 rows; the within-sample average is
        undefined then, and it would otherwise silently produce nan.
    """
    XX = kernel(X, X)
    # number of rows of X
    n = XX.shape[0]
    if n < 2:
        raise ValueError(f"kernel_error needs at least 2 rows in X, got {n}")
    XY = kernel(X, Y)
    # Drop the self-similarities k(x_i, x_i) so the first term is unbiased.
    within = (XX.sum() - np.trace(XX)) / (n * (n - 1))
    return within - 2 * XY.mean()

def kernel_noise(Y, kernel=lambda x, y: metrics.pairwise.rbf_kernel(x, y, gamma=None)):
    """
    Kernel entropy: an unbiased estimator of -||P||_k^2 with the rows of Y i.i.d. ~ P.

    Computes -mean_{i != j} k(y_i, y_j), i.e. the negated average of the
    off-diagonal entries of the Gram matrix kernel(Y, Y).

    Parameters
    ----------
    Y : ndarray of shape (n_samples, dim)
        Requires ``n_samples >= 2``.
    kernel : callable
        Pairwise kernel; defaults to sklearn ``rbf_kernel`` with ``gamma=None``.

    Returns
    -------
    float

    Raises
    ------
    ValueError
        If ``Y`` has fewer than 2 rows; the estimate is undefined then, and
        it would otherwise silently produce nan.
    """
    YY = kernel(Y, Y)
    # number of rows of Y
    n = YY.shape[0]
    if n < 2:
        raise ValueError(f"kernel_noise needs at least 2 rows in Y, got {n}")
    # trace - total = minus the sum of the off-diagonal entries
    return (np.trace(YY) - YY.sum()) / (n * (n - 1))

def sMMD(X, Y, kernel=lambda x, y: metrics.pairwise.rbf_kernel(x, y, gamma=None)):
    """Unbiased estimate of the squared maximum mean discrepancy between the samples X and Y.

    Computed as ``kernel_error(X, Y) - kernel_noise(Y)``, which expands to

        mean_{i != j} k(x_i, x_j) - 2 * mean_{i, l} k(x_i, y_l) + mean_{l != r} k(y_l, y_r).

    Parameters
    ----------
    X : ndarray of shape (n_sample1, dim)
        Requires ``n_sample1 >= 2``.
    Y : ndarray of shape (n_sample2, dim)
        Requires ``n_sample2 >= 2``.
    kernel : callable
        Pairwise kernel ``kernel(A, B)``. Defaults to sklearn's
        ``rbf_kernel`` with ``gamma=None``, i.e.
        k(x, y) = exp(-gamma * ||x - y||^2) with gamma = 1 / dim.

    Returns
    -------
    float
        The squared-MMD estimate. Being unbiased, it can be slightly negative.
    """
    return kernel_error(X, Y, kernel=kernel) - kernel_noise(Y, kernel=kernel)

# not used
def Var_inner_term(kernel_block_matrix, n_ensemble):
    """Average the within-block, off-main-diagonal entries of a kernel block matrix.

    K is an (n*m) x (n*m) matrix made of m x m blocks of size n x n
    (n = n_samples, m = n_ensemble). This returns the mean of {K_ii}_jt over
    i = 1..m and j != t, i.e. the diagonal blocks without their main diagonal.

    Parameters
    ----------
    kernel_block_matrix : ndarray of shape (n_ensemble * n_samples, n_ensemble * n_samples)

    n_ensemble : int
    """
    size = int(kernel_block_matrix.shape[0] / n_ensemble)

    within_blocks = []
    for b in range(n_ensemble):
        lo, hi = b * size, (b + 1) * size
        within_blocks.append(kernel_block_matrix[lo:hi, lo:hi])

    block_total = np.sum(within_blocks)
    trace_total = np.sum(kernel_block_matrix.diagonal())
    n_pairs = size * (size - 1) * n_ensemble
    return (block_total - trace_total) / n_pairs

# not used
def Cov_inner_term(kernel_block_matrix, n_ensemble):
    """Average all entries of the diagonal blocks K_ii of a kernel block matrix.

    K is an (n*m) x (n*m) matrix made of m x m blocks of size n x n
    (n = n_samples, m = n_ensemble). This returns the mean of {K_ii}_jt over
    i = 1..m and all j, t.

    Parameters
    ----------
    kernel_block_matrix : ndarray of shape (n_ensemble * n_samples, n_ensemble * n_samples)

    n_ensemble : int
    """
    size = int(kernel_block_matrix.shape[0] / n_ensemble)

    within_blocks = []
    for b in range(n_ensemble):
        lo, hi = b * size, (b + 1) * size
        within_blocks.append(kernel_block_matrix[lo:hi, lo:hi])

    return np.mean(within_blocks)

# not used
def Cov_outer_term(kernel_block_matrix, n_ensemble):
    """
    Average all entries of the off-diagonal blocks K_ij (i != j) of a kernel block matrix.

    K is an (n*m) x (n*m) matrix made of m x m blocks of size n x n
    (n = n_samples, m = n_ensemble). This returns the mean of {K_ij}_st over
    all i != j and all s, t.

    Parameters
    ----------
    kernel_block_matrix : ndarray of shape (n_ensemble * n_samples, n_ensemble * n_samples)

    n_ensemble : int
    """
    size = int(kernel_block_matrix.shape[0] / n_ensemble)

    def block(row, col):
        # n x n block at grid position (row, col)
        return kernel_block_matrix[row * size:(row + 1) * size, col * size:(col + 1) * size]

    cross_blocks = []
    for row in range(n_ensemble):
        for col in range(n_ensemble):
            if row == col:
                continue
            cross_blocks.append(block(row, col))

    return np.mean(cross_blocks)

# not used; is ~20% slower than the other implementation
def dist_var_(X, kernel=lambda x, y: metrics.pairwise.rbf_kernel(x, y, gamma=None)):
    """
    Distributional variance generated by kernel and estimated on samples X.

    Alternative formulation: the within-member average excluding
    self-pairs (``Var_inner_term``) minus the cross-member average
    (``Cov_outer_term``), both taken on the block kernel K(X, X).

    Parameters
    ----------
    X : ndarray of shape (n_ensemble, n_samples, n_features)

    kernel : callable
        Pairwise kernel; defaults to sklearn ``rbf_kernel`` with ``gamma=None``.
    """
    (n_ens, _, _) = X.shape
    gram = ensemble_block_kernel(X, X, kernel=kernel)
    inner = Var_inner_term(gram, n_ensemble=n_ens)
    outer = Cov_outer_term(gram, n_ensemble=n_ens)
    return inner - outer

# not used; is ~20% slower than the other implementation
def dist_cov_(X, Y, kernel=lambda x, y: metrics.pairwise.rbf_kernel(x, y, gamma=None)):
    """
    Distributional covariance generated by kernel between samples X and Y.

    Alternative formulation: the within-member average (``Cov_inner_term``)
    minus the cross-member average (``Cov_outer_term``), both taken on the
    block kernel K(X, Y).

    Parameters
    ----------
    X : ndarray of shape (n_ensemble, n_samples, n_features)

    Y : ndarray of shape (n_ensemble, n_samples, n_features)

    kernel : callable
        Pairwise kernel; defaults to sklearn ``rbf_kernel`` with ``gamma=None``.
    """
    (n_ens, _, _) = X.shape
    block_kernel = ensemble_block_kernel(X, Y, kernel=kernel)
    inner = Cov_inner_term(block_kernel, n_ensemble=n_ens)
    outer = Cov_outer_term(block_kernel, n_ensemble=n_ens)
    return inner - outer